import torch

from crossmodal_fusion import Multimodal_GatedFusion, Multimodal_AttentionFusion,Multimodal_Att_GatedFusion
# Smoke test for the fusion modules in crossmodal_fusion: build three random
# modality tensors of shape (batch_size, sen_size, hidden_size), run them
# through a fusion module, and print the fused output's shape.

# Dimensions shared by the module and its inputs. Defined once, up front,
# so the constructor below cannot drift out of sync with the tensor shapes.
batch_size = 2
sen_size = 3
hidden_size = 4

# Instantiate the fusion variant under test; uncomment one of the others
# to exercise it with the same inputs.
# gated_fusion = Multimodal_GatedFusion(hidden_size=hidden_size)
# gated_fusion = Multimodal_AttentionFusion(hidden_size=hidden_size)
gated_fusion = Multimodal_Att_GatedFusion(hidden_size=hidden_size)

# Seed so the smoke test is reproducible run-to-run.
torch.manual_seed(0)

# Three modality inputs (e.g. text / audio / video features), all the
# same shape: (batch_size, sen_size, hidden_size).
a = torch.randn(batch_size, sen_size, hidden_size)
b = torch.randn(batch_size, sen_size, hidden_size)
c = torch.randn(batch_size, sen_size, hidden_size)

print(a, b, c)

# Fuse the three modalities with the module under test.
result = gated_fusion(a, b, c)

# NOTE(review): inputs are 3-D, so the fused output is presumably
# (batch_size, sen_size, hidden_size) — confirm against the module's forward().
print(result.shape)